1 Starting out with TidyModels

TidyModels is the successor to Max Kuhn’s CARET and can be used for a number of machine learning tasks. This modelling framework takes a different approach to modelling - allowing for a more structured workflow - and, like the tidyverse, has a whole set of packages for making the machine learning process easier. I will touch on a number of these packages in the following subsections.

This framework supersedes the modelling approach in R for Data Science - Hadley Wickham admitted he needed a better modelling solution at the time, and Max Kuhn and team have delivered on this.

The aim of this webinar is to:

  • Teach you the steps to build an ML model from scratch
  • Work with rsample and recipes for feature engineering
  • Train and build a workflow with Parsnip
  • Evaluate your model with Yardstick and CARET
  • Improve your model with Tune and Dials

The framework of a TidyModels approach flows as follows:

I will show you the steps in the following tutorials.

2 Step one - Importing data into the R environment

I will load in the stranded patient data - a stranded patient is a patient who has been in hospital for longer than 7 days; we also call these Long Waiters. The import steps are below and use the readr package to load the data in:

# Read in the data
strand_pat <- read_csv("Data/Stranded_Data.csv") %>% 
  setNames(c("stranded_class", "age", "care_home_ref_flag", "medically_safe_flag", 
             "hcop_flag", "needs_mental_health_support_flag", "previous_care_in_last_12_month", "admit_date", "frail_descrip")) %>% 
  mutate(stranded_class = factor(stranded_class)) %>% 
  drop_na()
## 
## -- Column specification --------------------------------------------------------
## cols(
##   Stranded.label = col_character(),
##   Age = col_double(),
##   Care.home.referral = col_double(),
##   MedicallySafe = col_double(),
##   HCOP = col_double(),
##   Mental_Health_Care = col_double(),
##   Periods_of_previous_care = col_double(),
##   admit_date = col_character(),
##   frailty_index = col_character()
## )
print(head(strand_pat))
## # A tibble: 6 x 9
##   stranded_class   age care_home_ref_flag medically_safe_flag hcop_flag
##   <fct>          <dbl>              <dbl>               <dbl>     <dbl>
## 1 Not Stranded      50                  0                   0         0
## 2 Not Stranded      31                  1                   0         1
## 3 Not Stranded      32                  0                   1         0
## 4 Not Stranded      69                  1                   1         0
## 5 Not Stranded      33                  0                   0         1
## 6 Stranded          75                  1                   1         0
## # ... with 4 more variables: needs_mental_health_support_flag <dbl>,
## #   previous_care_in_last_12_month <dbl>, admit_date <chr>, frail_descrip <chr>

As this is a classification problem, we need to look at the class imbalance in the outcome variable, i.e. the thing we are trying to predict.

3 Step Two - Analysing the Class Imbalance

The following code looks at the class imbalance as a volume and as a proportion. I am then going to use the second index from the class balance table, i.e. the stranded patients - the number of people who are long waiters should be lower than the number who aren't, otherwise we are offering a very poor service to patients.

class_bal_table <- table(strand_pat$stranded_class)
prop_tab <- prop.table(class_bal_table)
upsample_ratio <- class_bal_table[2] / sum(class_bal_table)
print(prop_tab)
## 
## Not Stranded     Stranded 
##    0.6552217    0.3447783
print(class_bal_table)
## 
## Not Stranded     Stranded 
##          458          241
print(upsample_ratio)
##  Stranded 
## 0.3447783

4 Step Three - Observe data structures

It is always a good idea to observe the data structures of the data items we are working with. I generally separate the names of the variables into factors, integers / numerics and character vectors:

strand_pat$admit_date <- as.Date(strand_pat$admit_date, format="%d/%m/%Y") #Format date to be date to work with recipes steps
factors <- names(select_if(strand_pat, is.factor))
numbers <- names(select_if(strand_pat, is.numeric))
characters <- names(select_if(strand_pat, is.character))
print(factors); print(numbers); print(characters)
## [1] "stranded_class"
## [1] "age"                              "care_home_ref_flag"              
## [3] "medically_safe_flag"              "hcop_flag"                       
## [5] "needs_mental_health_support_flag" "previous_care_in_last_12_month"
## [1] "frail_descrip"

5 Step Four - Using Rsample to create ML data partitions

The Rsample package makes it easy to divide your data up. To view all the functionality navigate to the Rsample vignette.

We will divide the data into a training and a test sample. This approach is the simplest method of testing your model's accuracy and future performance on unseen data. Here we are going to treat the test data as the unseen data, to allow us to evaluate whether the model is fit to be released into the wild, or not.

# Partition into training and hold out test / validation sample
set.seed(123)
split <- rsample::initial_split(strand_pat, prop=3/4)
train_data <- rsample::training(split)
test_data <- rsample::testing(split)

6 Step Five - Creating your first Tidy Recipe

Recipes is an excellent package. I have for years done feature, dummy and other types of coding and feature selection with CARET, also a great package, but recipes makes the process much simpler. The first part of the recipe specifies the model formula and the data; you then add recipe steps - this is supposed to replicate baking, adding the specific ingredients. For all the particular steps that recipes contains, go directly to the recipes site.

stranded_rec <- 
  recipe(stranded_class ~ ., data=train_data) %>% 
  # The stranded class is what we are trying to predict and we are using the training data
  step_date(admit_date, features = c("dow", "month")) %>% 
  #Recipes step_date allows for additional features to be created from the date 
  step_rm(admit_date) %>% 
  #Remove the date, as we have created features from it; if left in, the dreaded multicollinearity may be present
  themis::step_upsample(stranded_class, over_ratio = as.numeric(upsample_ratio)) %>%  
  #Upsampling recipe step to boost the minority class i.e. stranded patients
  step_dummy(all_nominal(), -all_outcomes()) %>% 
  #Automatically created dummy variables for all categorical variables (nominal)
  step_zv(all_predictors()) %>% 
  #Get rid of features that have zero variance
  step_normalize(all_predictors()) #ML models train better when the data is centered and scaled

print(stranded_rec) #Terminology is to use recipe
## Data Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor          8
## 
## Operations:
## 
## Date features from admit_date
## Delete terms admit_date
## Up-sampling based on stranded_class
## Dummy variables from all_nominal(), -all_outcomes()
## Zero variance filter on all_predictors()
## Centering and scaling for all_predictors()

To look up some of these steps, I have previously covered them in a CARET tutorial. For the full list of recipes steps, refer to the link above the code chunk.

7 Step Six - Get Parsnipping

The package Parsnip is the modelling interface for TidyModels. Parsnip does not yet have many of the algorithms present in CARET, but it makes the ones it does have much simpler to work with in the tidy way.

Here we will create a basic logistic regression as our baseline model. If you want a second tutorial around model ensembling in TidyModels with Baguette and Stacks, then I would be happy to arrange this, but those are a session in themselves.

Logistic regression is the choice here because it is a nice generalised linear model that most people have encountered.
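
Under the hood, the GLM maps a linear predictor through the logistic (inverse-logit) link to produce a probability. A minimal base R sketch (the linear predictor values here are made up for illustration):

linear_predictor <- c(-2, 0, 2) #Made-up example values for illustration
prob <- 1 / (1 + exp(-linear_predictor)) #The inverse-logit link; equivalent to plogis(linear_predictor)
print(round(prob, 3))
## [1] 0.119 0.500 0.881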

TidyModels has a workflow structure which we will build in the next few steps:

7.1 Instantiate the model

In TidyModels you have to create an instance of the model in memory before working with it:

lr_mod <- 
  parsnip::logistic_reg() %>% 
  set_engine("glm")

print(lr_mod)
## Logistic Regression Model Specification (classification)
## 
## Computational engine: glm

The next step is to create the model workflow.

7.2 Creating the model workflow

Now it is time to build the workflow that connects the newly instantiated model and the recipe together:

# Create model workflow
strand_wf <- 
  workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(stranded_rec)

print(strand_wf)
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: logistic_reg()
## 
## -- Preprocessor ----------------------------------------------------------------
## 6 Recipe Steps
## 
## * step_date()
## * step_rm()
## * step_upsample()
## * step_dummy()
## * step_zv()
## * step_normalize()
## 
## -- Model -----------------------------------------------------------------------
## Logistic Regression Model Specification (classification)
## 
## Computational engine: glm

7.3 Fitting the workflow to our data

The next step is fitting the model to our data:

# Create the model fit
strand_fit <- 
  strand_wf %>% 
  fit(data = train_data)

7.4 Extracting the fitted data

The final step is to use the pull_workflow_fit() function to retrieve the fit from the workflow:

strand_fitted <- strand_fit %>% 
  pull_workflow_fit() %>% 
  tidy()

print(strand_fitted)
## # A tibble: 18 x 5
##    term                               estimate std.error statistic  p.value
##    <chr>                                 <dbl>     <dbl>     <dbl>    <dbl>
##  1 (Intercept)                        -0.242       0.172   -1.41   1.60e- 1
##  2 age                                 0.296       0.261    1.13   2.58e- 1
##  3 care_home_ref_flag                  0.204       0.115    1.78   7.57e- 2
##  4 medically_safe_flag                -0.173       0.120   -1.45   1.48e- 1
##  5 hcop_flag                          -0.0443      0.114   -0.390  6.97e- 1
##  6 needs_mental_health_support_flag    0.0646      0.116    0.558  5.77e- 1
##  7 previous_care_in_last_12_month      2.98        0.477    6.24   4.48e-10
##  8 frail_descrip_Fall.patient.history -0.185       0.150   -1.23   2.18e- 1
##  9 frail_descrip_Mobility.problems     0.0864      0.138    0.625  5.32e- 1
## 10 frail_descrip_No.index.item         0.144       0.279    0.517  6.05e- 1
## 11 admit_date_dow_Mon                 -0.149       0.169   -0.884  3.77e- 1
## 12 admit_date_dow_Tue                  0.0955      0.154    0.620  5.36e- 1
## 13 admit_date_dow_Wed                  0.232       0.162    1.44   1.51e- 1
## 14 admit_date_dow_Thu                  0.175       0.147    1.19   2.34e- 1
## 15 admit_date_dow_Fri                  0.0203      0.165    0.123  9.02e- 1
## 16 admit_date_dow_Sat                  0.181       0.150    1.20   2.29e- 1
## 17 admit_date_month_Feb                0.0144      0.133    0.108  9.14e- 1
## 18 admit_date_month_Dec                0.00950     0.122    0.0778 9.38e- 1

7.5 Create custom plot to visualise significance utilising p values

As an optional step, I have created a plot to visualise the significance. This will only work with linear and generalised linear models, which report a p value for each coefficient (found from the test statistic and the corresponding reference distribution). The visualisation code is below:

# Add significance column to tibble using mutate
strand_fitted <- strand_fitted  %>% 
  mutate(Significance = ifelse(p.value < 0.05, "Significant", "Insignificant")) %>% 
  arrange(desc(p.value)) 

#Create a ggplot object to visualise significance
plot <- strand_fitted %>% 
  ggplot(mapping = aes(x=term, y=p.value, fill=Significance)) +
  geom_col() + 
  theme(axis.text.x = element_text(face="bold", color="#0070BA", 
                                   size=8, angle=90)) + 
  labs(y="P value", x="Terms", 
       title="P value significance chart",
       subtitle="A chart to represent the significant variables in the model",
       caption="Produced by Gary Hutson")

#print("Creating plot of P values")
#print(plot)
plotly::ggplotly(plot)
#print(ggplotly(plot))
#ggsave("Figures/p_val_plot.png", plot) #Save the plot

8 Step Seven - Predicting with the holdout (test) dataset

Now we will assess how well the model predicts on the test (holdout) data to evaluate if we want to productionise the model, or abandon it at this stage. This is implemented below:

class_pred <- predict(strand_fit, test_data) #Get the class label predictions
prob_pred <- predict(strand_fit, test_data, type="prob") #Get the probability predictions
lr_predictions <- data.frame(class_pred, prob_pred) %>% 
  setNames(c("LR_Class", "LR_NotStrandedProb", "LR_StrandedProb")) #Combine into a data frame and rename

stranded_preds <- test_data %>% 
  bind_cols(lr_predictions)

print(tail(lr_predictions))
##         LR_Class LR_NotStrandedProb LR_StrandedProb
## 169 Not Stranded          0.8280968       0.1719032
## 170 Not Stranded          0.7650720       0.2349280
## 171 Not Stranded          0.8256528       0.1743472
## 172 Not Stranded          0.8405598       0.1594402
## 173 Not Stranded          0.8371738       0.1628262
## 174 Not Stranded          0.8501583       0.1498417

9 Step Eight - Evaluate the model fit with Yardstick and CARET (Confusion Matrices)

Yardstick is another tool in the TidyModels arsenal. It is useful for generating quick summary statistics and evaluation metrics. I will grab the area under the curve estimates to show how well the model fits:

roc_plot <- 
  stranded_preds %>% 
  roc_curve(truth = stranded_class, LR_NotStrandedProb) %>% 
  autoplot

print(roc_plot)

stranded_preds %>% 
  roc_auc(truth = stranded_class, LR_NotStrandedProb)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary          0.75
# Get area under the curve value - accuracy

I like ROC plots - but they only show the trade-off between sensitivity (how good the model is at predicting stranded) and specificity (how good it is at predicting not stranded). For binomial classification problems, I also like to look at the overall accuracy and balanced accuracy on a confusion matrix.
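
Balanced accuracy is simply the mean of sensitivity and specificity, so Yardstick can also compute it directly from the predictions; a quick sketch using the stranded_preds tibble created in the previous step:

# Balanced accuracy from yardstick - the mean of sensitivity and specificity,
# which is useful when the classes are imbalanced
stranded_preds %>% 
  yardstick::bal_accuracy(truth = stranded_class, estimate = LR_Class)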

I use the CARET package and utilise the confusion matrix functions to perform this:

library(caret)
## Loading required package: lattice
## 
## Attaching package: 'caret'
## The following objects are masked from 'package:yardstick':
## 
##     precision, recall, sensitivity, specificity
## The following object is masked from 'package:purrr':
## 
##     lift
cm <- caret::confusionMatrix(stranded_preds$stranded_class,
                       stranded_preds$LR_Class, 
                       positive="Stranded")

print(cm)
## Confusion Matrix and Statistics
## 
##               Reference
## Prediction     Not Stranded Stranded
##   Not Stranded          106        2
##   Stranded               35       31
##                                          
##                Accuracy : 0.7874         
##                  95% CI : (0.719, 0.8456)
##     No Information Rate : 0.8103         
##     P-Value [Acc > NIR] : 0.8093         
##                                          
##                   Kappa : 0.4998         
##                                          
##  Mcnemar's Test P-Value : 1.435e-07      
##                                          
##             Sensitivity : 0.9394         
##             Specificity : 0.7518         
##          Pos Pred Value : 0.4697         
##          Neg Pred Value : 0.9815         
##              Prevalence : 0.1897         
##          Detection Rate : 0.1782         
##    Detection Prevalence : 0.3793         
##       Balanced Accuracy : 0.8456         
##                                          
##        'Positive' Class : Stranded       
## 

9.1 Using ConfusionTableR package to visualise and flatten confusion matrix results

On the back of the Advanced Modelling course I did for the NHS-R Community I have created a package to work with the outputs of a confusion matrix. This package is aimed at the flattening of binary and multi-class confusion matrix results.

To load in the package you need to use the remotes package and bring in the ConfusionTableR package, which is available from my GitHub site.

#Load in my ConfusionTableR package to visualise this
#remotes::install_github("https://github.com/StatsGary/ConfusionTableR") #Use the remotes package to install 
#the package from GitHub
library(ConfusionTableR)
cm_plot <- ConfusionTableR::binary_visualiseR(cm, class_label1 = "Not Stranded", 
                     class_label2 = "Stranded",
                     quadrant_col1 = "#53BFD3", quadrant_col2 = "#006838", 
                     text_col = "white", custom_title = "Stranded patient Confusion Matrix")

# Flatten to store in database
#Stored confusion matrix

cm_results <- ConfusionTableR::binary_class_cm(cm)
print(cm_results)
##   Pred_Not.Stranded_Ref_Not.Stranded Pred_Stranded_Ref_Not.Stranded
## 1                                106                             35
##   Pred_Not.Stranded_Ref_Stranded Pred_Stranded_Ref_Stranded  Accuracy     Kappa
## 1                              2                         31 0.7873563 0.4997669
##   AccuracyLower AccuracyUpper AccuracyNull AccuracyPValue McnemarPValue
## 1     0.7190153     0.8456409    0.8103448      0.8092972  1.434553e-07
##   Sensitivity Specificity Pos.Pred.Value Neg.Pred.Value Precision    Recall
## 1   0.9393939    0.751773       0.469697      0.9814815  0.469697 0.9393939
##          F1 Prevalence Detection.Rate Detection.Prevalence Balanced.Accuracy
## 1 0.6262626  0.1896552      0.1781609            0.3793103         0.8455835
##                 cm_ts
## 1 2021-02-24 10:55:48

The next markdown document will look at how to improve your models with model selection, K-fold cross validation and hyperparameter tuning. I was thinking of doing an ensembling course off the back of this, so please contact me if that would be interesting to you.

10 Save the data for consumption in the next tutorials

I will now save the R image data into file, as we will pick this up in the next markdown document.

save.image(file="Data/stranded_data.rdata")

11 Resuming where we left off in the first Markdown document

The first markdown document showed you how to build your first TidyModels model on a healthcare dataset. This could be an ML model you simply tweak for your own uses. I will now load the data back in and resume where we left off:

load(file="Data/stranded_data.rdata")

12 Improve the model with resampling with the Rsample package

The first step will involve something called cross validation (see the supporting workshop slides). The essence of cross validation is that you take subsamples of the training dataset. This is done to emulate how well the model will perform on unseen data samples when out in the wild (production):

As the image shows - the folds take a sample of the training set and each randomly selected fold acts as the test sample. We then use a final hold-out validation set to finally test the model. This will be shown in the following section.

set.seed(123)
#Set a random seed for replication of results
ten_fold <- vfold_cv(train_data, v=10)

12.1 Use previous workflow with cross validation

We will use the previously trained logistic regression workflow and fit it to the resamples, to get a more robust estimate of model performance:

set.seed(123)
lr_fit_rs <- 
  strand_wf %>% 
  fit_resamples(ten_fold)
## 
## Attaching package: 'rlang'
## The following object is masked from 'package:magrittr':
## 
##     set_names
## The following objects are masked from 'package:purrr':
## 
##     %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
##     flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
##     splice
## 
## Attaching package: 'vctrs'
## The following object is masked from 'package:tibble':
## 
##     data_frame
## The following object is masked from 'package:dplyr':
## 
##     data_frame

We will now collect the metrics using the tune package and the collect_metrics function:

# To collect the resamples you need to call collect_metrics to average the accuracy across the folds
collected_mets <- tune::collect_metrics(lr_fit_rs)
print(collected_mets)
## # A tibble: 2 x 6
##   .metric  .estimator  mean     n std_err .config             
##   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy binary     0.754    10  0.0210 Preprocessor1_Model1
## 2 roc_auc  binary     0.712    10  0.0247 Preprocessor1_Model1
# Now I can compare the accuracy from the previous test set I had already generated a confusion matrix for
accuracy_resamples <- collected_mets$mean[1] * 100
accuracy_validation_set <- as.numeric(cm$overall[1] * 100)
cat(paste0("The true accuracy of the model is between the resample testing: ", 
           round(accuracy_resamples,2), "\nThe validation sample: ",
           round(accuracy_validation_set,2), "."))
## The true accuracy of the model is between the resample testing: 75.42
## The validation sample: 78.74.

This shows that the true accuracy value is somewhere between the reported results from the resampling method and those in our validation sample.

13 Improve the model with different model selection and resampling

The following example will move on from the logistic regression and aim to build a random forest, and later a decision tree. Other options in Parsnip would be to use a gradient boosted tree to amp up the results further. In addition, I aim to teach a follow-up webinar on ensembling - specifically model stacking (Stacks package) and bagging (Baguette package).

13.1 Define and instantiate the model

The first step, as with the logistic regression example, is to define and instantiate the model:

rf_mod <- 
  rand_forest(trees=500) %>% 
  set_engine("ranger") %>% 
  set_mode("classification")

print(rf_mod)
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   trees = 500
## 
## Computational engine: ranger

13.2 Fit the model to the previous training data

Then we are going to fit the model to the previous training data:

rf_fit <- 
  rf_mod %>% 
  fit(stranded_class ~ ., data = train_data)

print(rf_fit)
## parsnip model object
## 
## Fit time:  130ms 
## Ranger result
## 
## Call:
##  ranger::ranger(x = maybe_data_frame(x), y = y, num.trees = ~500,      num.threads = 1, verbose = FALSE, seed = sample.int(10^5,          1), probability = TRUE) 
## 
## Type:                             Probability estimation 
## Number of trees:                  500 
## Sample size:                      525 
## Number of independent variables:  8 
## Mtry:                             2 
## Target node size:                 10 
## Variable importance mode:         none 
## Splitrule:                        gini 
## OOB prediction error (Brier s.):  0.1734941

13.3 Improve further by fitting to resamples

We will aim to increase the sample representation for this model by fitting it to a resamples object, using parsnip and rsample:

#Create workflow step
rf_wf <- 
  workflow() %>% 
  add_model(rf_mod) %>% 
  add_formula(stranded_class ~ .) #The outcome and predictors are specified via the add_formula method

set.seed(123)
rf_fit_rs <- 
  rf_wf %>% 
  fit_resamples(ten_fold)

print(rf_fit_rs)
## # Resampling results
## # 10-fold cross-validation 
## # A tibble: 10 x 4
##    splits            id     .metrics         .notes          
##    <list>            <chr>  <list>           <list>          
##  1 <rsplit [472/53]> Fold01 <tibble [2 x 4]> <tibble [0 x 1]>
##  2 <rsplit [472/53]> Fold02 <tibble [2 x 4]> <tibble [0 x 1]>
##  3 <rsplit [472/53]> Fold03 <tibble [2 x 4]> <tibble [0 x 1]>
##  4 <rsplit [472/53]> Fold04 <tibble [2 x 4]> <tibble [0 x 1]>
##  5 <rsplit [472/53]> Fold05 <tibble [2 x 4]> <tibble [0 x 1]>
##  6 <rsplit [473/52]> Fold06 <tibble [2 x 4]> <tibble [0 x 1]>
##  7 <rsplit [473/52]> Fold07 <tibble [2 x 4]> <tibble [0 x 1]>
##  8 <rsplit [473/52]> Fold08 <tibble [2 x 4]> <tibble [0 x 1]>
##  9 <rsplit [473/52]> Fold09 <tibble [2 x 4]> <tibble [0 x 1]>
## 10 <rsplit [473/52]> Fold10 <tibble [2 x 4]> <tibble [0 x 1]>

13.4 Collect the resampled metrics

The next step is to collect the resample metrics:

# Collect the metrics using another model with resampling
rf_resample_mean_preds <- tune::collect_metrics(rf_fit_rs)
print(rf_resample_mean_preds)
## # A tibble: 2 x 6
##   .metric  .estimator  mean     n std_err .config             
##   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy binary     0.787    10  0.0156 Preprocessor1_Model1
## 2 roc_auc  binary     0.723    10  0.0245 Preprocessor1_Model1

The model's predictive power is maxing out at about 78%. I know this is due to the fact that the data is dummy data and most of the features contained in the model have a weak association with the outcome variable.

What you would need to do after this is look for more representative features of what causes a patient to stay a long time in hospital. This is where the clinical context comes into play.

14 Improve the model with hyperparameter tuning with the Dials package

We are now going to create a decision tree, and we are going to tune its hyperparameters using the dials package. The dials package contains a range of hyperparameter tuning tools and is useful for creating quick hyperparameter grids and optimising over them.

14.1 Building the decision tree

Like all the other steps, the first thing to do is build the decision tree. Note - the reason for set_mode("classification") is that the thing we are predicting is a factor. If this were a continuous variable, then you would need to switch this to regression. Otherwise, the model development process for regression is identical to classification.

tune_tree <- 
  decision_tree(
    cost_complexity = tune(), #tune() is a placeholder for an empty grid 
    tree_depth = tune() #we will fill these in the next section
  ) %>% 
  set_engine("rpart") %>% 
  set_mode("classification")

print(tune_tree)
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   cost_complexity = tune()
##   tree_depth = tune()
## 
## Computational engine: rpart
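
14.2 Creating the tuning grid with Dials

The grid_tree_tune object used in the later tuning step needs to be created first. A sketch of how such a grid could be built with dials::grid_regular - the levels value is an assumption, chosen to be consistent with the candidate cost complexity values shown later:

grid_tree_tune <- dials::grid_regular(
  dials::cost_complexity(), #Pruning penalty - dials supplies a sensible default range
  dials::tree_depth(), #Maximum depth of the tree
  levels = 10 #Assumed value: levels per parameter, crossed to form the candidate grid
)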

14.3 Setting up parallel processing

The tuning process, and the modelling process generally, normally needs the ML engineer to access the full potential of the machine. The next steps show how to register the cores on your machine and max them out for training the model and doing grid searching:

library(doParallel) #Provides registerDoParallel and the parallel backend
all_cores <- parallel::detectCores(logical = FALSE)-1
print(all_cores)
## [1] 3
#Detects the physical cores and subtracts one, so the machine stays responsive
cl <- makePSOCKcluster(all_cores)
print(cl)
## socket cluster with 3 nodes on host 'localhost'
#Makes an in-memory cluster to utilise your cores
registerDoParallel(cl)
#Registers that we want to do parallel processing

14.4 Creating the model workflow

Next, I will create the model workflow, as we have done a few times before:

set.seed(123)
tree_wf <- workflow() %>% 
  add_model(tune_tree) %>% 
  add_formula(stranded_class ~ .)
# Make the decision tree workflow - always postfix with wf for convention
# Add the registered model
# Add the formula of the outcome class you are predicting against all IVs

tree_pred_tuned <- 
  tree_wf %>% 
  tune::tune_grid(
    resamples = ten_fold, #This is the 10 fold cross validation variable we created earlier
    grid = grid_tree_tune #This is the tuning grid
  )

14.5 Visualise the tuning process

This ggplot helps to visualise how the tuning has gone and will show where the best tree depth occurs in terms of the cost complexity (the pruning penalty applied for each additional terminal, or leaf, node):

tune_plot <- tree_pred_tuned %>%
  collect_metrics() %>% #Collect metrics from tuning
  mutate(tree_depth = factor(tree_depth)) %>%
  ggplot(aes(cost_complexity, mean, color = tree_depth)) +
  geom_line(size = 1, alpha = 0.7) +
  geom_point(size = 1.5) +
  facet_wrap(~ .metric, scales = "free", nrow = 2) +
  scale_x_log10(labels = scales::label_number()) +
  scale_color_viridis_d(option = "plasma", begin = .9, end = 0) + theme_minimal()

print(tune_plot)

ggsave(filename="Figures/hyperparameter_tree.png", tune_plot)
## Saving 7 x 5 in image

This shows that you only need a depth of 4 to get the optimal accuracy. However, the tune package helps us out with this as well.

14.6 Selecting the best model from the tuning process with Tune

The tune package allows us to select the best candidate model, with the most optimal set of hyperparameters:

# To get the best ROC - area under the curve value we will use the following:
tree_pred_tuned %>% 
  tune::show_best("roc_auc")
## # A tibble: 5 x 8
##   cost_complexity tree_depth .metric .estimator  mean     n std_err .config     
##             <dbl>      <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>       
## 1    0.0000000001          4 roc_auc binary     0.707    10  0.0150 Preprocesso~
## 2    0.000000001           4 roc_auc binary     0.707    10  0.0150 Preprocesso~
## 3    0.00000001            4 roc_auc binary     0.707    10  0.0150 Preprocesso~
## 4    0.0000001             4 roc_auc binary     0.707    10  0.0150 Preprocesso~
## 5    0.000001              4 roc_auc binary     0.707    10  0.0150 Preprocesso~
# Select the best tree
best_tree <- tree_pred_tuned %>% 
  tune::select_best("roc_auc")

print(best_tree)
## # A tibble: 1 x 3
##   cost_complexity tree_depth .config               
##             <dbl>      <int> <chr>                 
## 1    0.0000000001          4 Preprocessor1_Model021

The next step is to use the best tree to make our predictions.

14.7 Using best tree to make predictions

final_wf <- 
  tree_wf %>% 
  finalize_workflow(best_tree) #Finalise workflow passes in our best tree

print(final_wf)
## == Workflow ====================================================================
## Preprocessor: Formula
## Model: decision_tree()
## 
## -- Preprocessor ----------------------------------------------------------------
## stranded_class ~ .
## 
## -- Model -----------------------------------------------------------------------
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   cost_complexity = 1e-10
##   tree_depth = 4
## 
## Computational engine: rpart

Make a prediction against this finalised tree:

final_tree_pred <- 
  final_wf %>% 
  fit(data = train_data)

print(final_tree_pred)
## == Workflow [trained] ==========================================================
## Preprocessor: Formula
## Model: decision_tree()
## 
## -- Preprocessor ----------------------------------------------------------------
## stranded_class ~ .
## 
## -- Model -----------------------------------------------------------------------
## n= 525 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 525 175 Not Stranded (0.6666667 0.3333333)  
##    2) previous_care_in_last_12_month< 1.5 439  99 Not Stranded (0.7744875 0.2255125)  
##      4) frail_descrip=Activity Limitation,Fall patient history 158  28 Not Stranded (0.8227848 0.1772152) *
##      5) frail_descrip=Mobility problems,No index item 281  71 Not Stranded (0.7473310 0.2526690)  
##       10) age< 34.5 128  23 Not Stranded (0.8203125 0.1796875) *
##       11) age>=34.5 153  48 Not Stranded (0.6862745 0.3137255)  
##         22) age>=37.5 139  40 Not Stranded (0.7122302 0.2877698) *
##         23) age< 37.5 14   6 Stranded (0.4285714 0.5714286) *
##    3) previous_care_in_last_12_month>=1.5 86  10 Stranded (0.1162791 0.8837209) *

14.8 Use VIP package to visualise variable importance

We will look at global variable importance. As mentioned earlier, to look at local patient-level importance, use the LIME package.

plot <- final_tree_pred %>% 
  pull_workflow_fit() %>% 
  vip(aesthetics = list(color = "black", fill = "#26ACB5")) + theme_minimal()

print(plot)

ggsave("Figures/VarImp.png", plot)
## Saving 7 x 5 in image

This echoes what we saw with the logistic regression, where the same variables came out as statistically significant.
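For local, patient-level explanations, a sketch using the lime package could look like the following. This is illustrative only: it assumes lime is installed and recognises the fitted parsnip model, and the number of features shown is arbitrary:

```r
library(lime)

# Extract the parsnip model fit from the trained workflow
tree_fit <- pull_workflow_fit(final_tree_pred)

# Build an explainer on the training predictors
explainer <- lime(train_data %>% select(-stranded_class), tree_fit)

# Explain the first three test patients at an individual level
explanations <- explain(test_data %>% select(-stranded_class) %>% head(3),
                        explainer, n_labels = 1, n_features = 5)

plot_features(explanations)
```

The plot_features output shows, per patient, which variables pushed the prediction towards or away from the stranded class.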

14.9 Create the final predictions

The last step is to create the final predictions from the tuned decision tree:

# Create the final prediction
final_fit <- 
  final_wf %>% 
  last_fit(split)

final_fit_fitted_metrics <- final_fit %>% 
  collect_metrics() 

print(final_fit_fitted_metrics)
## # A tibble: 2 x 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.770 Preprocessor1_Model1
## 2 roc_auc  binary         0.755 Preprocessor1_Model1
#Create the final predictions
final_fit_predictions <- final_fit %>% 
  collect_predictions()
print(final_fit_predictions)
## # A tibble: 174 x 7
##    id         `.pred_Not Strand~ .pred_Stranded  .row .pred_class stranded_class
##    <chr>                   <dbl>          <dbl> <int> <fct>       <fct>         
##  1 train/tes~              0.820          0.180     5 Not Strand~ Not Stranded  
##  2 train/tes~              0.712          0.288    13 Not Strand~ Stranded      
##  3 train/tes~              0.712          0.288    14 Not Strand~ Not Stranded  
##  4 train/tes~              0.823          0.177    20 Not Strand~ Not Stranded  
##  5 train/tes~              0.712          0.288    23 Not Strand~ Not Stranded  
##  6 train/tes~              0.712          0.288    26 Not Strand~ Stranded      
##  7 train/tes~              0.823          0.177    34 Not Strand~ Not Stranded  
##  8 train/tes~              0.712          0.288    39 Not Strand~ Stranded      
##  9 train/tes~              0.712          0.288    40 Not Strand~ Stranded      
## 10 train/tes~              0.823          0.177    41 Not Strand~ Not Stranded  
## # ... with 164 more rows, and 1 more variable: .config <chr>

14.10 Visualise the final fit on a ROC curve

You could view this object in a confusion matrix in a similar way, but here I will view it on a ROC plot:

roc_plot <- final_fit_predictions %>% 
  roc_curve(stranded_class, `.pred_Not Stranded`) %>% 
  autoplot() 

print(roc_plot)

ggsave(filename = "Figures/tuned_tree.png", plot=roc_plot)
## Saving 7 x 5 in image

15 Inspecting any Parsnip object

One last point to note: to inspect the tuning parameters and hyperparameters available for any parsnip model, use the args function. Examples below:

args(decision_tree)
## function (mode = "unknown", cost_complexity = NULL, tree_depth = NULL, 
##     min_n = NULL) 
## NULL
args(logistic_reg)
## function (mode = "classification", penalty = NULL, mixture = NULL) 
## NULL
args(rand_forest)
## function (mode = "unknown", mtry = NULL, trees = NULL, min_n = NULL) 
## NULL

16 Ensembling

The next section shows how to use the stacks package to ensemble several classifiers (logistic regression, random forest, XGBoost, a neural network and k-nearest neighbours) into a single meta model.

16.1 Create the models

We will reuse the recipe created for the first model. The next step is to initialise the candidate models.

lr_model <- logistic_reg() %>% 
  set_mode("classification") %>% 
  set_engine("glm")

rf_model <- rand_forest() %>% 
  set_mode("classification") %>% 
  set_engine("ranger")

xg_boost_model <- boost_tree() %>% 
  set_mode("classification") %>% 
  set_engine("xgboost")

nn_model <- mlp(epochs=300, hidden_units = 5, dropout = 0.5) %>% 
  set_mode("classification") %>% 
  set_engine("keras", verbose=0)

neighbour_model <- nearest_neighbor() %>% 
  set_engine("kknn") %>% 
  set_mode("classification")

# You could tune hyperparameters here, see previous step, for simplicity I am 
# just instantiating the models

16.2 Define the model workflows

Each of the models we have created will need its own workflow, so we will create one for each. The general rule in parsnip is: for every model, create its own workflow.

lr_wf <- workflow() %>% 
  add_model(lr_model) %>% 
  add_recipe(stranded_rec) #Use the stranded recipe we created once at the top

rf_wf <- workflow() %>% 
  add_model(rf_model) %>% 
  add_recipe(stranded_rec)

xgboost_wf <- workflow() %>% 
  add_model(xg_boost_model) %>% 
  add_recipe(stranded_rec)

nn_wf <- workflow() %>% 
  add_model(nn_model) %>% 
  add_recipe(stranded_rec)

neighbour_wf <- workflow() %>% 
  add_model(neighbour_model) %>% 
  add_recipe(stranded_rec)

print(lr_wf)
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: logistic_reg()
## 
## -- Preprocessor ----------------------------------------------------------------
## 6 Recipe Steps
## 
## * step_date()
## * step_rm()
## * step_upsample()
## * step_dummy()
## * step_zv()
## * step_normalize()
## 
## -- Model -----------------------------------------------------------------------
## Logistic Regression Model Specification (classification)
## 
## Computational engine: glm
print(rf_wf)
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: rand_forest()
## 
## -- Preprocessor ----------------------------------------------------------------
## 6 Recipe Steps
## 
## * step_date()
## * step_rm()
## * step_upsample()
## * step_dummy()
## * step_zv()
## * step_normalize()
## 
## -- Model -----------------------------------------------------------------------
## Random Forest Model Specification (classification)
## 
## Computational engine: ranger
print(xgboost_wf)
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: boost_tree()
## 
## -- Preprocessor ----------------------------------------------------------------
## 6 Recipe Steps
## 
## * step_date()
## * step_rm()
## * step_upsample()
## * step_dummy()
## * step_zv()
## * step_normalize()
## 
## -- Model -----------------------------------------------------------------------
## Boosted Tree Model Specification (classification)
## 
## Computational engine: xgboost
print(nn_wf)
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: mlp()
## 
## -- Preprocessor ----------------------------------------------------------------
## 6 Recipe Steps
## 
## * step_date()
## * step_rm()
## * step_upsample()
## * step_dummy()
## * step_zv()
## * step_normalize()
## 
## -- Model -----------------------------------------------------------------------
## Single Layer Neural Network Specification (classification)
## 
## Main Arguments:
##   hidden_units = 5
##   dropout = 0.5
##   epochs = 300
## 
## Engine-Specific Arguments:
##   verbose = 0
## 
## Computational engine: keras
print(neighbour_wf)
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: nearest_neighbor()
## 
## -- Preprocessor ----------------------------------------------------------------
## 6 Recipe Steps
## 
## * step_date()
## * step_rm()
## * step_upsample()
## * step_dummy()
## * step_zv()
## * step_normalize()
## 
## -- Model -----------------------------------------------------------------------
## K-Nearest Neighbor Model Specification (classification)
## 
## Computational engine: kknn

16.3 Fit the models using K-Fold Cross Validation

We will now fit our five candidate models using resampling, as k-fold cross-validation gives a more reliable picture of performance across the underlying training data:

model_control <- control_grid(save_pred = TRUE, save_workflow = TRUE)

#Register cluster for parallel processing
registerDoParallel(cl)

system.time(lr_fit <- fit_resamples(
  lr_wf,
  resamples = ten_fold,
  control = model_control
))
## The workflow being saved contains a recipe, which is 0.09 Mb in memory. If this was not intentional, please set the control setting `save_workflow = FALSE`.
##    user  system elapsed 
##    0.12    0.01    3.09
system.time(rf_fit <- fit_resamples(
  rf_wf,
  resamples = ten_fold,
  control = model_control
))
## The workflow being saved contains a recipe, which is 0.09 Mb in memory. If this was not intentional, please set the control setting `save_workflow = FALSE`.
##    user  system elapsed 
##    0.09    0.02    3.50
system.time(xgboost_fit <- fit_resamples(
  xgboost_wf,
  resamples = ten_fold,
   control = model_control
))
## The workflow being saved contains a recipe, which is 0.09 Mb in memory. If this was not intentional, please set the control setting `save_workflow = FALSE`.
##    user  system elapsed 
##    0.13    0.03    3.22
system.time(nn_fit <- fit_resamples(
  nn_wf, 
  resamples = ten_fold, 
  control = model_control
))
## ! Some required packages prohibit parallel processing:  'keras'
## 
## Attaching package: 'keras'
## The following object is masked from 'package:yardstick':
## 
##     get_weights
## The workflow being saved contains a recipe, which is 0.09 Mb in memory. If this was not intentional, please set the control setting `save_workflow = FALSE`.
##    user  system elapsed 
##  114.07    4.26   88.79
system.time(neighbours_fit <- fit_resamples(
  neighbour_wf, 
  resamples = ten_fold, 
  control = model_control
))
## The workflow being saved contains a recipe, which is 0.09 Mb in memory. If this was not intentional, please set the control setting `save_workflow = FALSE`.
##    user  system elapsed 
##    0.11    0.07    3.24

16.4 Create a stacking ensemble

We will now load the stacks package to create a meta model from the logistic regression, random forest, XGBoost, neural network and k-nearest neighbour classifiers.

library(stacks)
meta_stacked_model <- stacks() %>% 
  add_candidates(lr_fit) %>% 
  add_candidates(rf_fit) %>% 
  add_candidates(xgboost_fit) %>% 
  add_candidates(nn_fit) %>% 
  add_candidates(neighbours_fit)

print(meta_stacked_model)
## # A data stack with 5 model definitions and 5 candidate members:
## #   lr_fit: 1 model configuration
## #   rf_fit: 1 model configuration
## #   xgboost_fit: 1 model configuration
## #   nn_fit: 1 model configuration
## #   neighbours_fit: 1 model configuration
## # Outcome: stranded_class (factor)

16.5 The stacks model as a tibble

Really, the stacks model is a fancy tibble. To view it as such, use as_tibble to cast it:

as_tibble(meta_stacked_model)
## # A tibble: 525 x 11
##    stranded_class `.pred_Not Stranded_~ .pred_Stranded_lr_~ `.pred_Not Stranded~
##    <fct>                          <dbl>               <dbl>                <dbl>
##  1 Not Stranded                 0.766                0.234                 0.707
##  2 Not Stranded                 0.813                0.187                 0.854
##  3 Not Stranded                 0.722                0.278                 0.441
##  4 Not Stranded                 0.502                0.498                 0.769
##  5 Stranded                     0.581                0.419                 0.601
##  6 Not Stranded                 0.855                0.145                 0.881
##  7 Not Stranded                 0.920                0.0801                0.869
##  8 Not Stranded                 0.00154              0.998                 0.205
##  9 Not Stranded                 0.736                0.264                 0.635
## 10 Not Stranded                 0.901                0.0991                0.930
## # ... with 515 more rows, and 7 more variables:
## #   .pred_Stranded_rf_fit_1_1 <dbl>, .pred_Not Stranded_xgboost_fit_1_1 <dbl>,
## #   .pred_Stranded_xgboost_fit_1_1 <dbl>, .pred_Not Stranded_nn_fit_1_1 <dbl>,
## #   .pred_Stranded_nn_fit_1_1 <dbl>,
## #   .pred_Not Stranded_neighbours_fit_1_1 <dbl>,
## #   .pred_Stranded_neighbours_fit_1_1 <dbl>

This shows that it generates probability predictions for the outcome variable, by each model type.

Blending predictions

This stacks the models; now we will blend the predictions from the candidate members to use in our meta model:

meta_stacked_model <- meta_stacked_model %>% 
  blend_predictions()
print(meta_stacked_model)
## -- A stacked ensemble model -------------------------------------
## 
## Out of 5 possible candidate members, the ensemble retained 2.
## Lasso penalty: 0.001.
## 
## The 2 highest weighted member classes are:
## # A tibble: 2 x 3
##   member                    type         weight
##   <chr>                     <chr>         <dbl>
## 1 .pred_Stranded_rf_fit_1_1 rand_forest    2.40
## 2 .pred_Stranded_lr_fit_1_1 logistic_reg   2.14
## 
## Members have not yet been fitted with `fit_members()`.

The blend_predictions function determines how member model output will ultimately be combined in the final prediction by fitting a LASSO model on the data stack, predicting the true assessment set outcome using the predictions from each of the candidate members. Candidates with nonzero stacking coefficients become members.

Now, we will use autoplot on this to look at the trade-off between minimising the number of members, whilst optimising performance:

theme_set(theme_minimal())
autoplot(meta_stacked_model)

autoplot(meta_stacked_model, type="members")

autoplot(meta_stacked_model, type="weights")

Fitting candidate models into a stack

Once we have analysed the diagnostics, and viewed the accuracy and ROC AUC, we can then use the fit_members() function to fit our candidate models:

#Register cluster for parallel processing
registerDoParallel(cl)
system.time(meta_stacked_model <- meta_stacked_model %>% 
  fit_members())
##    user  system elapsed 
##    0.71    1.03    4.86
print(meta_stacked_model)
## -- A stacked ensemble model -------------------------------------
## 
## Out of 5 possible candidate members, the ensemble retained 2.
## Lasso penalty: 0.001.
## 
## The 2 highest weighted member classes are:
## # A tibble: 2 x 3
##   member                    type         weight
##   <chr>                     <chr>         <dbl>
## 1 .pred_Stranded_rf_fit_1_1 rand_forest    2.40
## 2 .pred_Stranded_lr_fit_1_1 logistic_reg   2.14

Model stacks can be thought of as a group of fitted member models and a set of instructions on how to combine their predictions.

16.6 Using stacked model to predict test data

The next stage is to make predictions on the held-out test set, as resampling was only used to generate different samples from the training data:

test_data <- 
  test_data %>% 
  bind_cols(predict(meta_stacked_model, .),
            predict(meta_stacked_model, ., type="prob")) #Expose the prediction probabilities

print(test_data)
## # A tibble: 174 x 12
##    stranded_class   age care_home_ref_flag medically_safe_flag hcop_flag
##    <fct>          <dbl>              <dbl>               <dbl>     <dbl>
##  1 Not Stranded      33                  0                   0         1
##  2 Stranded          80                  1                   1         0
##  3 Not Stranded      72                  1                   0         0
##  4 Not Stranded      75                  1                   0         0
##  5 Not Stranded      68                  0                   1         1
##  6 Stranded          41                  0                   1         1
##  7 Not Stranded      78                  0                   1         0
##  8 Stranded          60                  1                   0         0
##  9 Stranded          56                  1                   0         0
## 10 Not Stranded      70                  0                   0         0
## # ... with 164 more rows, and 7 more variables:
## #   needs_mental_health_support_flag <dbl>,
## #   previous_care_in_last_12_month <dbl>, admit_date <date>,
## #   frail_descrip <chr>, .pred_class <fct>, .pred_Not Stranded <dbl>,
## #   .pred_Stranded <dbl>

To go further, you could visualise the distribution of the ensemble's predicted probabilities against the actual stranded labels.
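A minimal density-plot sketch of this, assuming ggplot2 is loaded and test_data contains the prediction columns created above:

```r
# Density of the predicted probability of Stranded, split by actual class
dist_plot <- ggplot(test_data, aes(x = .pred_Stranded, fill = stranded_class)) +
  geom_density(alpha = 0.5) +
  labs(x = "Predicted probability of Stranded",
       fill = "Actual class") +
  theme_minimal()

print(dist_plot)
```

A well-separated pair of densities suggests the ensemble is discriminating between the two classes.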

16.7 Assess meta model with confusion matrix

We will now assess the hold out sample with the confusion matrix:

cm <- caret::confusionMatrix(test_data$stranded_class,
                       test_data$.pred_class, 
                       positive="Stranded")

print(cm)
## Confusion Matrix and Statistics
## 
##               Reference
## Prediction     Not Stranded Stranded
##   Not Stranded          105        3
##   Stranded               34       32
##                                          
##                Accuracy : 0.7874         
##                  95% CI : (0.719, 0.8456)
##     No Information Rate : 0.7989         
##     P-Value [Acc > NIR] : 0.687          
##                                          
##                   Kappa : 0.503          
##                                          
##  Mcnemar's Test P-Value : 8.14e-07       
##                                          
##             Sensitivity : 0.9143         
##             Specificity : 0.7554         
##          Pos Pred Value : 0.4848         
##          Neg Pred Value : 0.9722         
##              Prevalence : 0.2011         
##          Detection Rate : 0.1839         
##    Detection Prevalence : 0.3793         
##       Balanced Accuracy : 0.8348         
##                                          
##        'Positive' Class : Stranded       
## 

The ensemble performs well, especially in terms of sensitivity for the stranded class; however, I think the class imbalance in this dataset would need further adjustment.

Visualising this using the ConfusionTableR package:

cm_plot <- ConfusionTableR::binary_visualiseR(cm, class_label1 = "Not Stranded", 
                     class_label2 = "Stranded",
                     quadrant_col1 = "#53BFD3", quadrant_col2 = "#006838", 
                     text_col = "white", custom_title = "Stranded patient Confusion Matrix")
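As an alternative that stays within tidymodels, a yardstick metric set could be applied to the same columns; a sketch assuming test_data holds the .pred_class column created above:

```r
library(yardstick)

# Bundle several classification metrics together
multi_metric <- metric_set(accuracy, sens, spec, f_meas)

test_data %>% 
  multi_metric(truth = stranded_class, estimate = .pred_class)
```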

16.8 Hyperparameter tuning

The next phase of this would be to undertake hyperparameter tuning with dials on the candidate models to further improve the accuracy of the meta model. I suspect, however, that the best approach would be to find more variables in the system that are indicative of a patient becoming stranded.
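A sketch of what that tuning could look like for one candidate, using tune() placeholders and the control_stack_grid() helper from stacks so that the tuning results remain usable with add_candidates(); the grid size here is arbitrary:

```r
# Mark mtry and min_n for tuning in the random forest candidate
rf_tune_model <- rand_forest(mtry = tune(), min_n = tune()) %>% 
  set_mode("classification") %>% 
  set_engine("ranger")

rf_tune_wf <- workflow() %>% 
  add_model(rf_tune_model) %>% 
  add_recipe(stranded_rec)

# Tune over a small grid; control_stack_grid() saves the predictions
# and workflows that stacks needs downstream
rf_tune_res <- tune_grid(
  rf_tune_wf,
  resamples = ten_fold,
  grid = 5,
  control = stacks::control_stack_grid()
)

# Each grid point then becomes a candidate member via add_candidates()
```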